{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.contrib.learn.python.learn.datasets import mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<dtype: 'float32'>\n",
      "(10,)\n",
      "(tf.float32, tf.int32)\n",
      "(TensorShape([]), TensorShape([Dimension(100)]))\n",
      "(tf.float32, (tf.float32, tf.int32))\n",
      "(TensorShape([Dimension(10)]), (TensorShape([]), TensorShape([Dimension(100)])))\n"
     ]
    }
   ],
   "source": [
    "dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))\n",
    "print(dataset1.output_types)  # ==> \"tf.float32\"\n",
    "print(dataset1.output_shapes)  # ==> \"(10,)\"\n",
    "\n",
    "dataset2 = tf.data.Dataset.from_tensor_slices(\n",
    "   (tf.random_uniform([4]),\n",
    "    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))\n",
    "print(dataset2.output_types)  # ==> \"(tf.float32, tf.int32)\"\n",
    "print(dataset2.output_shapes)  # ==> \"((), (100,))\"\n",
    "\n",
    "dataset3 = tf.data.Dataset.zip((dataset1, dataset2))\n",
    "print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))\n",
    "print(dataset3.output_shapes)  # ==> \"(10, ((), (100,)))\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': tf.float32, 'b': tf.int32}\n",
      "{'a': TensorShape([]), 'b': TensorShape([Dimension(100)])}\n"
     ]
    }
   ],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices(\n",
    "   {\"a\": tf.random_uniform([4]),\n",
    "    \"b\": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})\n",
    "print(dataset.output_types)  # ==> \"{'a': tf.float32, 'b': tf.int32}\"\n",
    "print(dataset.output_shapes)  # ==> \"{'a': (), 'b': (100,)}\"\n"
   ]
  },
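  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Added sketch (not part of the original notebook): `Dataset.map` applies a per-element transformation. For a dict-structured dataset the element is passed to the mapping function as a single dict, so components can be accessed by key. The particular combination below is purely illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: combine the two components of each dict element.\n",
    "# d is the dict {'a': <scalar float32>, 'b': <int32 vector of length 100>}.\n",
    "mapped = dataset.map(\n",
    "    lambda d: d[\"a\"] + tf.cast(d[\"b\"][0], tf.float32))\n",
    "print(mapped.output_types)   # ==> tf.float32\n",
    "print(mapped.output_shapes)  # ==> ()"
   ]
  },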
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One-hot Iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess=tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(100)\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "next_element = iterator.get_next()\n",
    "\n",
    "for i in range(100):\n",
    "  value = sess.run(next_element)\n",
    "  assert i == value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## initializable Iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_value = tf.placeholder(tf.int64, shape=[])\n",
    "dataset = tf.data.Dataset.range(max_value)\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "next_element = iterator.get_next()\n",
    "\n",
    "sess.run(iterator.initializer, feed_dict={max_value: 10})\n",
    "for i in range(10):\n",
    "  value = sess.run(next_element)\n",
    "  assert i == value\n",
    "\n",
    "sess.run(iterator.initializer, feed_dict={max_value: 100})\n",
    "for i in range(100):\n",
    "  value = sess.run(next_element)\n",
    "  assert i == value"
   ]
  },
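  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another common use of an initializable iterator (an added sketch, not from the original notebook): feed NumPy arrays through placeholders when the iterator is initialized, so the data is not embedded in the graph as constants. The `features` and `labels` arrays below are made-up toy data; the `sess` from above is reused."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy arrays; in practice these would be real NumPy data.\n",
    "features = np.random.rand(8, 2).astype(np.float32)\n",
    "labels = np.random.randint(0, 2, size=8).astype(np.int64)\n",
    "\n",
    "features_ph = tf.placeholder(features.dtype, features.shape)\n",
    "labels_ph = tf.placeholder(labels.dtype, labels.shape)\n",
    "dataset = tf.data.Dataset.from_tensor_slices((features_ph, labels_ph))\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "next_feature, next_label = iterator.get_next()\n",
    "\n",
    "# The arrays are fed only once, when the iterator is initialized.\n",
    "sess.run(iterator.initializer,\n",
    "         feed_dict={features_ph: features, labels_ph: labels})\n",
    "print(sess.run((next_feature, next_label)))"
   ]
  },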
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reinitializable Itarator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义训练集和验证集\n",
    "training_dataset = tf.data.Dataset.range(100).map(\n",
    "    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))\n",
    "validation_dataset = tf.data.Dataset.range(50)\n",
    "\n",
    "# reinitializable的iterator需要定义它的结构，我们使用训练集的结构来定义它。当然也要求验证集是同样的结构。\n",
    "iterator = tf.data.Iterator.from_structure(training_dataset.output_types,\n",
    "                                           training_dataset.output_shapes)\n",
    "next_element = iterator.get_next()\n",
    "\n",
    "training_init_op = iterator.make_initializer(training_dataset)\n",
    "validation_init_op = iterator.make_initializer(validation_dataset)\n",
    "\n",
    "# 运行20个Epoch\n",
    "for _ in range(20):\n",
    "  # 首先用训练集来初始化Iterator\n",
    "  sess.run(training_init_op)\n",
    "  for _ in range(100):\n",
    "    sess.run(next_element)\n",
    "\n",
    "  # 再使用验证集来初始化Iterator\n",
    "  sess.run(validation_init_op)\n",
    "  for _ in range(50):\n",
    "    sess.run(next_element)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## feedable Iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-17ed142b48cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     26\u001b[0m   \u001b[0;31m# 因此下一次训练会接着从前面数据\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m   \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m     \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext_element\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtraining_handle\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m   \u001b[0;31m# 验证集需要我们初始化\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.6-env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    903\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    904\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 905\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    906\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    907\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.6-env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1135\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1136\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1137\u001b[0;31m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m   1138\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1139\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.6-env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1323\u001b[0m       \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1324\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1325\u001b[0;31m       \u001b[0mfeeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1326\u001b[0m       \u001b[0mfetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_name_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1327\u001b[0m       \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_name_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.6-env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   1323\u001b[0m       \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1324\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1325\u001b[0;31m       \u001b[0mfeeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1326\u001b[0m       \u001b[0mfetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_name_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1327\u001b[0m       \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_name_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.6-env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36mname\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    318\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_op\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    319\u001b[0m       \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Operation was not named: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_op\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0;34m\"%s:%d\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_op\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_value_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    322\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# 定义训练集和验证集\n",
    "# 这里的训练集后面使用了repeate，因此是可以无限词迭代的。\n",
    "training_dataset = tf.data.Dataset.range(100).map(\n",
    "    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()\n",
    "validation_dataset = tf.data.Dataset.range(50)\n",
    "\n",
    "# feedable iterator除了需要Dataset的结构外还需要一个string的placeholder\n",
    "handle = tf.placeholder(tf.string, shape=[])\n",
    "iterator = tf.data.Iterator.from_string_handle(\n",
    "    handle, training_dataset.output_types, training_dataset.output_shapes)\n",
    "next_element = iterator.get_next()\n",
    "\n",
    "# reinitializable只有一个Iterator，通过session.run(init_op)来切换不同数据集\n",
    "# 而feedable Iterator每个数据集有自己的Iterator，这样更加灵活，比如训练集我们是one-hot的Iterator\n",
    "# 而验证集使用initializable的Iterator\n",
    "training_iterator = training_dataset.make_one_shot_iterator()\n",
    "validation_iterator = validation_dataset.make_initializable_iterator()\n",
    "\n",
    "# 训练Iterator和验证Iterator都有一个string的handle，我们需要得到它，从而可以feed给feedable Iterator\n",
    "training_handle = sess.run(training_iterator.string_handle())\n",
    "validation_handle = sess.run(validation_iterator.string_handle())\n",
    "\n",
    "# 无限的循环，需要自己interrupt它\n",
    "while True:\n",
    "  # 训练集上运行200次训练。注意训练dataset是用repeat得到的可遍历无限次的Iteartor\n",
    "  # 因此下一次训练会接着从前面数据 \n",
    "  for _ in range(200):\n",
    "    sess.run(next_element, feed_dict={handle: training_handle})\n",
    "\n",
    "  # 验证集需要我们初始化\n",
    "  sess.run(validation_iterator.initializer)\n",
    "  for _ in range(50):\n",
    "    sess.run(next_element, feed_dict={handle: validation_handle})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(5)\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "next_element = iterator.get_next()\n",
    " \n",
    "result = tf.add(next_element, next_element)\n",
    "\n",
    "sess.run(iterator.initializer)\n",
    "print(sess.run(result))  # ==> \"0\"\n",
    "print(sess.run(result))  # ==> \"2\"\n",
    "print(sess.run(result))  # ==> \"4\"\n",
    "print(sess.run(result))  # ==> \"6\"\n",
    "print(sess.run(result))  # ==> \"8\"\n",
    "try:\n",
    "  sess.run(result)\n",
    "except tf.errors.OutOfRangeError:\n",
    "  print(\"End of dataset\")  # ==> \"End of dataset\"\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(5)\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "next_element = iterator.get_next()\n",
    " \n",
    "result = tf.add(next_element, next_element)\n",
    "sess.run(iterator.initializer)\n",
    "while True:\n",
    "  try:\n",
    "    print(sess.run(result))\n",
    "  except tf.errors.OutOfRangeError:\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))\n",
    "dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))\n",
    "dataset3 = tf.data.Dataset.zip((dataset1, dataset2))\n",
    "\n",
    "iterator = dataset3.make_initializable_iterator()\n",
    "\n",
    "sess.run(iterator.initializer)\n",
    "next1, (next2, next3) = iterator.get_next()\n"
   ]
  },
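  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Usage note (added): the nested tensors returned by `get_next()` should be fetched in a single `Session.run` call, because all components of one element belong to the same iterator step; fetching them in separate calls would advance the iterator once per call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch all components of one element in a single run() call.\n",
    "value1, (value2, value3) = sess.run((next1, (next2, next3)))\n",
    "print(value1.shape, value2.shape, value3.shape)  # ==> (10,) () (100,)"
   ]
  },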
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用dataset api来读取tfrecord数据并进行训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成tfrecord文件，和前面一样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "save_dir = \"../data\"\n",
    "# Download data to save_dir\n",
    "data_sets = mnist.read_data_sets(save_dir,\n",
    "                                 dtype=tf.uint8,\n",
    "                                 reshape=False,\n",
    "                                 validation_size=1000)\n",
    "\n",
    "data_splits = [\"train\", \"test\", \"validation\"]\n",
    "for d in range(len(data_splits)):\n",
    "    print(\"saving \" + data_splits[d])\n",
    "    data_set = data_sets[d]\n",
    "    filename = os.path.join(save_dir, data_splits[d] + '.tfrecords')\n",
    "    writer = tf.python_io.TFRecordWriter(filename)\n",
    "    print(data_set.images[0].shape, data_set.images[0].dtype)\n",
    "    for index in range(data_set.images.shape[0]):\n",
    "        image = data_set.images[index].tostring()\n",
    "        # image = data_set.images[index].reshape(-1).astype(np.int32)\n",
    "        example = tf.train.Example(features=tf.train.Features(feature={\n",
    "            'height': tf.train.Feature(int64_list=tf.train.Int64List(value=\n",
    "                                                                     [data_set.images.shape[1]])),\n",
    "            'width': tf.train.Feature(int64_list=tf.train.Int64List(value=\n",
    "                                                                    [data_set.images.shape[2]])),\n",
    "            'depth': tf.train.Feature(int64_list=tf.train.Int64List(value=\n",
    "                                                                    [data_set.images.shape[3]])),\n",
    "            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=\n",
    "                                                                    [int(data_set.labels[index]), 2])),\n",
    "        # 为了演示，这里给label多加一个整数2\n",
    "            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=\n",
    "                                                                        [image]))}))\n",
    "        # 'image_raw': tf.train.Feature(int64_list=tf.train.Int64List(value =\n",
    "        #        image))}))\n",
    "        writer.write(example.SerializeToString())\n",
    "    writer.close()"
   ]
  },
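  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (added sketch, not part of the original flow): read the first serialized record back with `tf.python_io.tf_record_iterator` and inspect the stored features before building the input pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first record of the training split written above.\n",
    "record_path = os.path.join(save_dir, \"train.tfrecords\")\n",
    "serialized = next(tf.python_io.tf_record_iterator(record_path))\n",
    "example = tf.train.Example.FromString(serialized)\n",
    "feats = example.features.feature\n",
    "print(\"height:\", feats[\"height\"].int64_list.value)\n",
    "print(\"width: \", feats[\"width\"].int64_list.value)\n",
    "print(\"label: \", feats[\"label\"].int64_list.value)  # two ints: the digit plus the extra 2"
   ]
  },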
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_function(example_proto):\n",
    "    features = {\"image_raw\": tf.FixedLenFeature([], tf.string),\n",
    "                \"label\": tf.FixedLenFeature([2], tf.int64)}\n",
    "    parsed_features = tf.parse_single_example(example_proto, features)\n",
    "    image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)\n",
    "    image.set_shape([784])\n",
    "    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5\n",
    "\n",
    "    label = tf.cast(parsed_features['label'][0], tf.int32)\n",
    "    return image, label\n",
    "\n",
    "\n",
    "filenames = [\"../data/train.tfrecords\"]\n",
    "dataset = tf.data.TFRecordDataset(filenames)\n",
    "dataset = dataset.map(_parse_function)\n",
    "dataset = dataset.shuffle(buffer_size=10000)\n",
    "dataset = dataset.batch(128)\n",
    "dataset = dataset.repeat(10)\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "images_batch, labels_batch = iterator.get_next()\n",
    "\n",
    "W = tf.get_variable(\"W\", [28 * 28, 10])\n",
    "y_pred = tf.matmul(images_batch, W)\n",
    "loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_pred, labels=labels_batch)\n",
    "loss_mean = tf.reduce_mean(loss)\n",
    "train_op = tf.train.AdamOptimizer().minimize(loss)\n",
    "sess = tf.Session()\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "init = tf.local_variables_initializer()\n",
    "sess.run(init)\n",
    "sess.run(iterator.initializer)\n",
    "step = 0\n",
    "while True:\n",
    "    try:\n",
    "        step += 1\n",
    "        sess.run([train_op])\n",
    "        if step % 500 == 0:\n",
    "            loss_mean_val = sess.run([loss_mean])\n",
    "            print(step)\n",
    "            print(loss_mean_val)\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        break\n",
    "\n",
    "sess.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.6-env",
   "language": "python",
   "name": "py3.6-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
